{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import argparse\n",
    "import datetime\n",
    "import math\n",
    "import pickle\n",
    "\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from utils.autoaugment import CIFAR10Policy\n",
    "\n",
    "import torch\n",
    "import torch.utils.data as data\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "from torch.autograd import Variable\n",
    "from torchvision import datasets\n",
    "from torch.utils.data.sampler import SubsetRandomSampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.NonBayesianModels import conv_init\n",
    "from utils.NonBayesianModels.AlexNet import AlexNet\n",
    "from utils.NonBayesianModels.LeNet import LeNet\n",
    "from utils.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_type = 'alexnet'\n",
    "dataset = 'CIFAR10'\n",
    "outputs = 10\n",
    "inputs = 3\n",
    "n_epochs = 20\n",
    "lr = 0.001\n",
    "resize=32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper Parameter settings\n",
    "use_cuda = torch.cuda.is_available()\n",
    "torch.cuda.set_device(0)"
   ]
  },
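  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check (a sketch, not part of the original run): report which\n",
    "# device the training will use.\n",
    "print('CUDA available:', use_cuda)\n",
    "if use_cuda:\n",
    "    print('device:', torch.cuda.get_device_name(0))"
   ]
  },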
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of subprocesses to use for data loading\n",
    "num_workers = 0\n",
    "# how many samples per batch to load\n",
    "batch_size = 64\n",
    "# percentage of training set to use as validation\n",
    "valid_size = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# convert data to a normalized torch.FloatTensor\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "\n",
    "# choose the training and test datasets\n",
    "train_data = datasets.CIFAR10('data', train=True,\n",
    "                              download=True, transform=transform)\n",
    "test_data = datasets.CIFAR10('data', train=False,\n",
    "                             download=True, transform=transform)"
   ]
  },
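  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The Normalize transform maps each channel from [0, 1] to [-1, 1] via\n",
    "# x -> (x - 0.5) / 0.5. A quick sanity check (a sketch, not part of the\n",
    "# original pipeline): invert it and confirm pixels land back in [0, 1].\n",
    "img, label = train_data[0]\n",
    "print('normalized range: [%.2f, %.2f]' % (img.min(), img.max()))\n",
    "unnormalized = img * 0.5 + 0.5\n",
    "print('unnormalized range: [%.2f, %.2f]' % (unnormalized.min(), unnormalized.max()))"
   ]
  },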
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# obtain training indices that will be used for validation\n",
    "num_train = len(train_data)\n",
    "indices = list(range(num_train))\n",
    "np.random.shuffle(indices)\n",
    "split = int(np.floor(valid_size * num_train))\n",
    "train_idx, valid_idx = indices[split:], indices[:split]"
   ]
  },
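  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (a sketch): the split should cover the training set exactly\n",
    "# once, with an 80/20 train/validation ratio and no overlap.\n",
    "print('train: %d, valid: %d' % (len(train_idx), len(valid_idx)))\n",
    "assert set(train_idx).isdisjoint(valid_idx)\n",
    "assert len(train_idx) + len(valid_idx) == num_train"
   ]
  },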
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define samplers for obtaining training and validation batches\n",
    "train_sampler = SubsetRandomSampler(train_idx)\n",
    "valid_sampler = SubsetRandomSampler(valid_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare data loaders (combine dataset and sampler)\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,\n",
    "    sampler=train_sampler, num_workers=num_workers)\n",
    "valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, \n",
    "    sampler=valid_sampler, num_workers=num_workers)\n",
    "test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, \n",
    "    num_workers=num_workers)"
   ]
  },
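  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check (a sketch): pull one batch and confirm shapes. Because\n",
    "# SubsetRandomSampler draws only from its index list, the train loader\n",
    "# yields ceil(len(train_idx) / batch_size) batches per epoch.\n",
    "images, labels = next(iter(train_loader))\n",
    "print(images.shape, labels.shape)  # expected: [64, 3, 32, 32] and [64]\n",
    "print('train batches per epoch:', len(train_loader))"
   ]
  },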
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify the image classes\n",
    "classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
    "           'dog', 'frog', 'horse', 'ship', 'truck']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AlexNet(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(5, 5))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): Dropout(p=0.5)\n",
      "    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (4): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "    (5): ReLU(inplace)\n",
      "    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (7): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (8): ReLU(inplace)\n",
      "    (9): Dropout(p=0.5)\n",
      "    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): ReLU(inplace)\n",
      "    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (13): ReLU(inplace)\n",
      "    (14): Dropout(p=0.5)\n",
      "    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (classifier): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Architecture\n",
    "if (net_type == 'lenet'):\n",
    "    net = LeNet(outputs,inputs)\n",
    "elif (net_type == 'alexnet'):\n",
    "    net = AlexNet(outputs,inputs)\n",
    "elif (net_type == '3conv3fc'):\n",
    "        net = ThreeConvThreeFC(outputs,inputs)\n",
    "else:\n",
    "    print('Error : Network should be either [LeNet / AlexNet / 3Conv3FC')\n",
    "\n",
    "print(net)\n"
   ]
  },
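  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (a sketch): run a dummy forward pass on CPU before moving\n",
    "# the model to the GPU. This assumes the network accepts\n",
    "# (N, inputs, resize, resize) tensors, as the CIFAR-10 loaders provide.\n",
    "with torch.no_grad():\n",
    "    dummy = torch.randn(2, inputs, resize, resize)\n",
    "    print(net(dummy).shape)  # expected: torch.Size([2, 10])"
   ]
  },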
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# move tensors to GPU if CUDA is available\n",
    "if use_cuda:\n",
    "    net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify loss function (categorical cross-entropy)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# specify optimizer\n",
    "optimizer = optim.Adam(net.parameters(), lr=lr)"
   ]
  },
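  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A rough capacity check (a sketch): count the trainable parameters that\n",
    "# Adam will update.\n",
    "n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)\n",
    "print('trainable parameters: {:,}'.format(n_params))"
   ]
  },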
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'model_alexnet_CIFAR10_frequentist.pt'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ckpt_name = f'model_{net_type}_{dataset}_frequentist.pt'\n",
    "ckpt_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 \tTraining Loss: 1.431397 \tValidation Loss: 0.327709\n",
      "Validation loss decreased (inf --> 0.327709).  Saving model ...\n",
      "Epoch: 2 \tTraining Loss: 1.196846 \tValidation Loss: 0.303414\n",
      "Validation loss decreased (0.327709 --> 0.303414).  Saving model ...\n",
      "Epoch: 3 \tTraining Loss: 1.109582 \tValidation Loss: 0.284857\n",
      "Validation loss decreased (0.303414 --> 0.284857).  Saving model ...\n",
      "Epoch: 4 \tTraining Loss: 1.058778 \tValidation Loss: 0.282861\n",
      "Validation loss decreased (0.284857 --> 0.282861).  Saving model ...\n",
      "Epoch: 5 \tTraining Loss: 1.015377 \tValidation Loss: 0.272893\n",
      "Validation loss decreased (0.282861 --> 0.272893).  Saving model ...\n",
      "Epoch: 6 \tTraining Loss: 0.985682 \tValidation Loss: 0.264794\n",
      "Validation loss decreased (0.272893 --> 0.264794).  Saving model ...\n",
      "Epoch: 7 \tTraining Loss: 0.961130 \tValidation Loss: 0.257850\n",
      "Validation loss decreased (0.264794 --> 0.257850).  Saving model ...\n",
      "Epoch: 8 \tTraining Loss: 0.942639 \tValidation Loss: 0.253482\n",
      "Validation loss decreased (0.257850 --> 0.253482).  Saving model ...\n",
      "Epoch: 9 \tTraining Loss: 0.921052 \tValidation Loss: 0.251446\n",
      "Validation loss decreased (0.253482 --> 0.251446).  Saving model ...\n",
      "Epoch: 10 \tTraining Loss: 0.892646 \tValidation Loss: 0.244815\n",
      "Validation loss decreased (0.251446 --> 0.244815).  Saving model ...\n",
      "Epoch: 11 \tTraining Loss: 0.884392 \tValidation Loss: 0.250876\n",
      "Epoch: 12 \tTraining Loss: 0.864065 \tValidation Loss: 0.239296\n",
      "Validation loss decreased (0.244815 --> 0.239296).  Saving model ...\n",
      "Epoch: 13 \tTraining Loss: 0.844942 \tValidation Loss: 0.239256\n",
      "Validation loss decreased (0.239296 --> 0.239256).  Saving model ...\n",
      "Epoch: 14 \tTraining Loss: 0.840974 \tValidation Loss: 0.238582\n",
      "Validation loss decreased (0.239256 --> 0.238582).  Saving model ...\n",
      "Epoch: 15 \tTraining Loss: 0.829902 \tValidation Loss: 0.234766\n",
      "Validation loss decreased (0.238582 --> 0.234766).  Saving model ...\n",
      "Epoch: 16 \tTraining Loss: 0.814025 \tValidation Loss: 0.230650\n",
      "Validation loss decreased (0.234766 --> 0.230650).  Saving model ...\n",
      "Epoch: 17 \tTraining Loss: 0.806663 \tValidation Loss: 0.227011\n",
      "Validation loss decreased (0.230650 --> 0.227011).  Saving model ...\n",
      "Epoch: 18 \tTraining Loss: 0.792850 \tValidation Loss: 0.224854\n",
      "Validation loss decreased (0.227011 --> 0.224854).  Saving model ...\n",
      "Epoch: 19 \tTraining Loss: 0.787480 \tValidation Loss: 0.229408\n",
      "Epoch: 20 \tTraining Loss: 0.782817 \tValidation Loss: 0.232852\n",
      "CPU times: user 2min 49s, sys: 10.9 s, total: 3min\n",
      "Wall time: 3min\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "valid_loss_min = np.Inf # track change in validation loss\n",
    "\n",
    "for epoch in range(1, n_epochs+1):\n",
    "\n",
    "    # keep track of training and validation loss\n",
    "    train_loss = 0.0\n",
    "    valid_loss = 0.0\n",
    "    \n",
    "    ###################\n",
    "    # train the model #\n",
    "    ###################\n",
    "    net.train()\n",
    "    for data, target in train_loader:\n",
    "        # move tensors to GPU if CUDA is available\n",
    "        if use_cuda:\n",
    "            data, target = data.cuda(), target.cuda()\n",
    "        # clear the gradients of all optimized variables\n",
    "        optimizer.zero_grad()\n",
    "        # forward pass: compute predicted outputs by passing inputs to the model\n",
    "        output = net(data)\n",
    "        # calculate the batch loss\n",
    "        loss = criterion(output, target)\n",
    "        # backward pass: compute gradient of the loss with respect to model parameters\n",
    "        loss.backward()\n",
    "        # perform a single optimization step (parameter update)\n",
    "        optimizer.step()\n",
    "        # update training loss\n",
    "        train_loss += loss.item()*data.size(0)\n",
    "        \n",
    "    ######################    \n",
    "    # validate the model #\n",
    "    ######################\n",
    "    net.eval()\n",
    "    for data, target in valid_loader:\n",
    "        # move tensors to GPU if CUDA is available\n",
    "        if use_cuda:\n",
    "            data, target = data.cuda(), target.cuda()\n",
    "        # forward pass: compute predicted outputs by passing inputs to the model\n",
    "        output = net(data)\n",
    "        # calculate the batch loss\n",
    "        loss = criterion(output, target)\n",
    "        # update average validation loss \n",
    "        valid_loss += loss.item()*data.size(0)\n",
    "    \n",
    "    # calculate average losses\n",
    "    train_loss = train_loss/len(train_loader.dataset)\n",
    "    valid_loss = valid_loss/len(valid_loader.dataset)\n",
    "        \n",
    "    # print training/validation statistics \n",
    "    print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n",
    "        epoch, train_loss, valid_loss))\n",
    "    \n",
    "    # save model if validation loss has decreased\n",
    "    if valid_loss <= valid_loss_min:\n",
    "        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(\n",
    "        valid_loss_min,\n",
    "        valid_loss))\n",
    "        torch.save(net.state_dict(), ckpt_name)\n",
    "        valid_loss_min = valid_loss\n"
   ]
  },
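  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The loop above checkpoints only the model weights. A resumable variant\n",
    "# (a sketch; the run above did not save this file) also keeps the\n",
    "# optimizer state and the epoch counter:\n",
    "checkpoint = {\n",
    "    'epoch': epoch,\n",
    "    'model_state_dict': net.state_dict(),\n",
    "    'optimizer_state_dict': optimizer.state_dict(),\n",
    "    'valid_loss_min': valid_loss_min,\n",
    "}\n",
    "torch.save(checkpoint, 'resume_' + ckpt_name)"
   ]
  },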
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.load_state_dict(torch.load(ckpt_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 1.126139\n",
      "\n",
      "Test Accuracy of airplane: 67% (672/1000)\n",
      "Test Accuracy of automobile: 71% (717/1000)\n",
      "Test Accuracy of  bird: 47% (476/1000)\n",
      "Test Accuracy of   cat: 47% (478/1000)\n",
      "Test Accuracy of  deer: 55% (550/1000)\n",
      "Test Accuracy of   dog: 49% (497/1000)\n",
      "Test Accuracy of  frog: 82% (822/1000)\n",
      "Test Accuracy of horse: 66% (667/1000)\n",
      "Test Accuracy of  ship: 79% (796/1000)\n",
      "Test Accuracy of truck: 72% (723/1000)\n",
      "\n",
      "Test Accuracy (Overall): 63% (6398/10000)\n",
      "CPU times: user 1.49 s, sys: 79.8 ms, total: 1.57 s\n",
      "Wall time: 1.57 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# track test loss# track  \n",
    "test_loss = 0.0\n",
    "class_correct = list(0. for i in range(10))\n",
    "class_total = list(0. for i in range(10))\n",
    "\n",
    "net.eval()\n",
    "# iterate over test data\n",
    "for data, target in test_loader:\n",
    "    # move tensors to GPU if CUDA is available\n",
    "    if use_cuda:\n",
    "        data, target = data.cuda(), target.cuda()\n",
    "    # forward pass: compute predicted outputs by passing inputs to the model\n",
    "    output = net(data)\n",
    "    # calculate the batch loss\n",
    "    loss = criterion(output, target)\n",
    "    # update test loss \n",
    "    test_loss += loss.item()*data.size(0)\n",
    "    # convert output probabilities to predicted class\n",
    "    _, pred = torch.max(output, 1)    \n",
    "    # compare predictions to true label\n",
    "    correct_tensor = pred.eq(target.data.view_as(pred))\n",
    "    correct = np.squeeze(correct_tensor.numpy()) if not use_cuda else np.squeeze(correct_tensor.cpu().numpy())\n",
    "    # calculate test accuracy for each object class\n",
    "    for i in range(batch_size):\n",
    "        if i >= target.data.shape[0]: # batch_size could be greater than left number of images\n",
    "            break\n",
    "        label = target.data[i]\n",
    "        class_correct[label] += correct[i].item()\n",
    "        class_total[label] += 1\n",
    "\n",
    "# average test loss\n",
    "test_loss = test_loss/len(test_loader.dataset)\n",
    "print('Test Loss: {:.6f}\\n'.format(test_loss))\n",
    "\n",
    "for i in range(10):\n",
    "    if class_total[i] > 0:\n",
    "        print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n",
    "            classes[i], 100 * class_correct[i] / class_total[i],\n",
    "            np.sum(class_correct[i]), np.sum(class_total[i])))\n",
    "    else:\n",
    "        print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n",
    "\n",
    "print('\\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (\n",
    "    100. * np.sum(class_correct) / np.sum(class_total),\n",
    "    np.sum(class_correct), np.sum(class_total)))"
   ]
  },
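  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Beyond per-class accuracy, a confusion matrix shows which classes get\n",
    "# mixed up (a sketch using only torch; rows = true class, columns =\n",
    "# predicted class, so off-diagonal mass marks common confusions).\n",
    "confusion = torch.zeros(10, 10, dtype=torch.long)\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    for data, target in test_loader:\n",
    "        if use_cuda:\n",
    "            data, target = data.cuda(), target.cuda()\n",
    "        _, pred = torch.max(net(data), 1)\n",
    "        for t, p in zip(target.cpu(), pred.cpu()):\n",
    "            confusion[t, p] += 1\n",
    "print(confusion)"
   ]
  }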
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
